Time series clustering partitions a set of time series into groups based on a similarity or distance measure, so that series within the same cluster resemble one another and series in different clusters do not.
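As a minimal, hedged illustration of that idea (separate from the EMG pipeline below; the toy windows array and the choice of k = 3 are invented for this sketch), short fixed-length series can be clustered directly with k-means under Euclidean distance:

import numpy as np
from sklearn.cluster import KMeans

# Toy stand-in: 100 series of length 10; in practice these would be real windows.
rng = np.random.default_rng(0)
windows = rng.standard_normal((100, 10))
labels = KMeans(n_clusters=3, n_init=10, random_state=0).fit_predict(windows)
print(labels[:10])  # cluster index assigned to each series

The pipeline below instead learns a compressed latent representation of each window with a variational recurrent autoencoder (VRAE) and works in that latent space.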
Methodology followed:
from vrae.vrae import VRAE
from vrae.utils import *
from vrae.utils_EMG import *
import numpy as np
import torch
import pickle
from sklearn.decomposition import PCA, TruncatedSVD
from sklearn.manifold import TSNE
from sklearn.metrics import mean_squared_error as mse
import matplotlib.pyplot as plt
import plotly
from torch.utils.data import DataLoader, TensorDataset
plotly.offline.init_notebook_mode()
%load_ext autoreload
%autoreload 2
dload = './model_dir'
seq_len = 10
hidden_size = 256
hidden_layer_depth = 3
latent_length = 16
batch_size = 32
learning_rate = 0.00002
n_epochs = 1500
dropout_rate = 0.0
optimizer = 'Adam' # options: Adam, SGD
cuda = True # options: True, False
print_every = 10
clip = True # options: True, False
max_grad_norm = 5
loss = 'MSELoss' # options: SmoothL1Loss, MSELoss
block = 'LSTM' # options: LSTM, GRU
output = False
training_file = ['20201020_Pop_Cage_001','20201020_Pop_Cage_002','20201020_Pop_Cage_003','20201020_Pop_Cage_004',
'20201020_Pop_Cage_006']
X_train, y_train, X_train_ori, X_pca = load_data(direc = 'data', dataset="EMG", all_file = training_file,
do_pca = True, single_channel = None,
batch_size = batch_size, seq_len = seq_len, pca_component = 6)
train_dataset = TensorDataset(torch.from_numpy(X_train))
Loading 20201020_Pop_Cage_001, X shape (3599, 150, 1), y shape (3599, 1), has label [-1. 0. 1. 2. 3.]
Loading 20201020_Pop_Cage_002, X shape (3599, 150, 1), y shape (3599, 1), has label [-1. 0. 1. 2. 3.]
Loading 20201020_Pop_Cage_003, X shape (3599, 150, 1), y shape (3599, 1), has label [-1. 0. 1. 2. 3. 4.]
Loading 20201020_Pop_Cage_004, X shape (3601, 150, 1), y shape (3601, 1), has label [-1. 0. 1. 2. 3. 4.]
Loading 20201020_Pop_Cage_006, X shape (3599, 150, 1), y shape (3599, 1), has label [-1. 0. 1. 2. 3. 4.]
Doing PCA
Explained variance ratio: [0.7165347 0.82274211 0.88277787 0.92401944 0.9436966 0.95823764 0.97067536 0.9778943 0.98333348 0.98764237 0.99133662 0.99435286 0.99654689 0.99843935 1. ]
Dataset shape: (17984, 10, 6)
Label: [-1. 0. 1. 2. 3. 4.], shape: (17984, 1)
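For reference, here is a sketch of the windowing-plus-PCA preprocessing that load_data presumably performs; the function name and the non-overlapping windowing are assumptions, while seq_len = 10 and pca_component = 6 come from the call above:

from sklearn.decomposition import PCA

def window_and_pca_sketch(x, seq_len=10, n_components=6):
    # x: (n_samples, n_channels) multichannel EMG
    pca = PCA(n_components=n_components)
    x_red = pca.fit_transform(x)  # (n_samples, n_components)
    # The log above appears to print the cumulative explained variance ratio
    print('Explained variance ratio:', pca.explained_variance_ratio_.cumsum())
    n_win = x_red.shape[0] // seq_len  # non-overlapping windows (assumed)
    return x_red[:n_win * seq_len].reshape(n_win, seq_len, n_components), pca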
num_features = X_train.shape[2]
VRAE inherits from sklearn.base.BaseEstimator and overrides the fit, transform, and fit_transform methods, so it can be driven like any other sklearn estimator.
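Because the estimator API is followed, the calls compose the usual sklearn way; a hedged summary of the pattern used in the cells below (fit_transform is part of the upstream timeseries-clustering-vae interface):

# vrae.fit(train_dataset)                 # train the autoencoder
# z = vrae.transform(train_dataset)       # encode each window to a latent vector
# z = vrae.fit_transform(train_dataset)   # both steps in one call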
vrae = VRAE(sequence_length=seq_len,
number_of_features = num_features,
hidden_size = hidden_size,
hidden_layer_depth = hidden_layer_depth,
latent_length = latent_length,
batch_size = batch_size,
learning_rate = learning_rate,
n_epochs = n_epochs,
dropout_rate = dropout_rate,
optimizer = optimizer,
cuda = cuda,
print_every=print_every,
clip=clip,
max_grad_norm=max_grad_norm,
loss = loss,
block = block,
dload = dload,
output = output)
/home/roton2/miniconda3/envs/emg/lib/python3.9/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.
#If the model has to be saved with the learnt parameters, use:
#vrae.fit(train_dataset, save = True)
vrae.fit(train_dataset)
Epoch: 9 Average loss: 3634387.2977
Epoch: 19 Average loss: 2966009.5853
Epoch: 29 Average loss: 2410859.4656
...
Epoch: 1479 Average loss: 65484.1307
Epoch: 1489 Average loss: 65179.1381
Epoch: 1499 Average loss: 64604.0790
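For context, the quantity reported above is the standard VAE objective: the configured reconstruction loss (MSELoss with reduction='sum', consistent with the UserWarning printed at instantiation) plus the KL divergence of the Gaussian latent. A minimal sketch of that computation; the function and argument names are illustrative, not the package's internals:

import torch
import torch.nn as nn

def vrae_loss_sketch(x_decoded, x, latent_mean, latent_logvar):
    # Summed MSE reconstruction term + KL( N(mu, sigma^2) || N(0, I) )
    recon = nn.MSELoss(reduction='sum')(x_decoded, x)
    kl = -0.5 * torch.sum(1 + latent_logvar - latent_mean.pow(2) - latent_logvar.exp())
    return recon + kl

x = torch.randn(32, 10, 6)  # toy batch shaped like this dataset
print(vrae_loss_sketch(x + 0.1 * torch.randn_like(x), x,
                       torch.zeros(32, 16), torch.zeros(32, 16)).item())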
plt.plot(vrae.all_loss)  # training loss recorded during fit
plt.plot(vrae.rec_mse)   # reconstruction MSE recorded during fit
#If the latent vectors have to be saved, pass the parameter `save`
z_run = vrae.transform(train_dataset, save = True, filename = 'z_run_e57_b32_z16_pca.pkl')
z_run.shape
(17984, 16)
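With one 16-dimensional latent vector per window, the clustering step itself is a one-liner; a hedged sketch (k = 7 simply mirrors the number of behavior labels defined further below and is not tuned for this data):

from sklearn.cluster import KMeans

kmeans = KMeans(n_clusters=7, n_init=10, random_state=0)
cluster_ids = kmeans.fit_predict(z_run)  # one cluster index per window
print(np.bincount(cluster_ids))          # cluster sizes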
vrae.save('./vrae_e57_b32_z16_pca.pth')
vrae.load(dload+'/vrae_e5_3000epoch.pth')  # load a previously trained checkpoint
with open(dload+'/z_run_e57pca_2000epoch.pkl', 'rb') as fh:
    z_run = pickle.load(fh)
reconstruction = recon(vrae, X_train)
plot_recon_feature(X_train, reconstruction, idx = None)
_, _, _ = plot_recon_metrics(X_train, reconstruction, x_lim = [2000, 4000])
Channel 1, corr = 0.9362, mse = 23.780959, mean = -0.0000.
Channel 2, corr = 0.8773, mse = 26.516869, mean = 0.0000.
Channel 3, corr = 0.7950, mse = 33.826289, mean = 0.0000.
Channel 4, corr = 0.8020, mse = 34.875095, mean = -0.0000.
Channel 5, corr = 0.6730, mse = 38.746768, mean = 0.0000.
Channel 6, corr = 0.6078, mse = 42.494420, mean = -0.0000.
recon_channel = pca_inverse(X_pca, reconstruction)
plot_recon_feature(X_train_ori, recon_channel, idx = None)
_, _, _ = plot_recon_metrics(X_train_ori, recon_channel, x_lim = [0, 2000])
Channel 1, corr = 0.6976, mse = 59.334525, mean = 29.5886.
Channel 2, corr = 0.6834, mse = 44.348163, mean = 27.4895.
Channel 3, corr = 0.6004, mse = 83.301298, mean = 31.6063.
Channel 4, corr = 0.5489, mse = 42.786141, mean = 19.6259.
Channel 5, corr = 0.5853, mse = 31.435098, mean = 13.4139.
Channel 6, corr = 0.6485, mse = 40.583734, mean = 32.0427.
Channel 7, corr = 0.8230, mse = 67.725698, mean = 49.2383.
Channel 8, corr = 0.8175, mse = 78.008913, mean = 54.5515.
Channel 9, corr = 0.6623, mse = 34.414231, mean = 21.3511.
Channel 10, corr = 0.6938, mse = 48.115765, mean = 30.8874.
Channel 11, corr = 0.8155, mse = 25.464214, mean = 46.5397.
Channel 12, corr = 0.5068, mse = 129.173548, mean = 21.5676.
Channel 13, corr = 0.8431, mse = 104.643200, mean = 50.0767.
Channel 14, corr = 0.7858, mse = 65.437832, mean = 39.4550.
Channel 15, corr = 0.7587, mse = 58.439501, mean = 36.4381.
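pca_inverse presumably undoes the 6-component projection with the fitted PCA's inverse transform, window by window, recovering the original 15 channels; a minimal sketch under that assumption:

def pca_inverse_sketch(pca, recon):
    # recon: (n_windows, seq_len, n_components) -> (n_windows, seq_len, n_channels)
    n_win, seq_len, n_comp = recon.shape
    flat = pca.inverse_transform(recon.reshape(-1, n_comp))  # back to channel space
    return flat.reshape(n_win, seq_len, -1)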
testing_file = ['20201020_Pop_Cage_005', '20201020_Pop_Cage_007']
X_test, y_test, X_test_ori, test_pca = load_data(direc = 'data', dataset="EMG", all_file = testing_file,
do_pca = True, single_channel = None,
batch_size = batch_size, seq_len = seq_len, pca_component = 6)
Loading 20201020_Pop_Cage_005, X shape (3599, 150, 1), y shape (3599, 1), has label [-1. 0. 1. 2. 3. 5.]
Loading 20201020_Pop_Cage_007, X shape (3599, 150, 1), y shape (3599, 1), has label [-1. 0. 1. 2. 3. 4.]
Doing PCA
Explained variance ratio: [0.72963202 0.82936616 0.88681766 0.92899866 0.9469495 0.9610898 0.97256296 0.97950278 0.98456741 0.98865807 0.99216837 0.99489398 0.99686689 0.99867192 1. ]
Dataset shape: (7168, 10, 6)
Label: [-1. 0. 1. 2. 3. 4. 5.], shape: (7168, 1)
# PCA was used when loading, so map the test reconstructions back to channel space
recon_test = recon(vrae, X_test)
recon_channel_test = pca_inverse(test_pca, recon_test)
# Uncomment to evaluate in PCA space instead:
#plot_recon_feature(X_test, recon_test, idx = None)
plot_recon_feature(X_test_ori, recon_channel_test, idx = None)
# corr_mean, mse_mean, mean_mean = plot_recon_metrics(X_test, recon_test, x_lim = [0, 2000])
corr_mean, mse_mean, mean_mean = plot_recon_metrics(X_test_ori, recon_channel_test, x_lim = [0, 2000])
Channel 1, corr = 0.6165, mse = 99.489509, mean = 29.5584.
Channel 2, corr = 0.6033, mse = 70.222180, mean = 27.4149.
Channel 3, corr = 0.4859, mse = 127.250266, mean = 32.5453.
Channel 4, corr = 0.4505, mse = 53.838572, mean = 19.5454.
Channel 5, corr = 0.4766, mse = 40.331625, mean = 13.0227.
Channel 6, corr = 0.4050, mse = 205.489407, mean = 30.9896.
Channel 7, corr = 0.6929, mse = 231.395477, mean = 50.5355.
Channel 8, corr = 0.7114, mse = 201.009913, mean = 55.5833.
Channel 9, corr = 0.5750, mse = 51.893745, mean = 20.8698.
Channel 10, corr = 0.5659, mse = 124.627128, mean = 31.7638.
Channel 11, corr = 0.6147, mse = 481.011196, mean = 46.6334.
Channel 12, corr = 0.4061, mse = 155.956307, mean = 21.2069.
Channel 13, corr = 0.7331, mse = 247.815192, mean = 52.4404.
Channel 14, corr = 0.6575, mse = 194.315561, mean = 41.4571.
Channel 15, corr = 0.6278, mse = 184.912212, mean = 38.8859.
print(list(corr_mean))
print(list(mse_mean))
print(list(mean_mean))
[0.6165009849157875, 0.6033037277269683, 0.48588493963679075, 0.4504800846552252, 0.47656022065930687, 0.4049794121687208, 0.692857358052444, 0.7113978078404368, 0.5749525270351056, 0.5658823514683109, 0.6146899621996763, 0.40611010065547354, 0.7330519196906536, 0.6575181347036434, 0.627820533007272]
[99.48950861730694, 70.22217991722833, 127.25026620373713, 53.838571652820086, 40.331625399595445, 205.48940697351048, 231.39547718837096, 201.00991255225534, 51.89374520602131, 124.62712780272132, 481.01119647880796, 155.95630695184673, 247.81519220076794, 194.31556089468455, 184.91221179502304]
[29.558443014706768, 27.414924678304136, 32.54528668084857, 19.545427740051515, 13.022679916323572, 30.98961445901738, 50.535490176304165, 55.58334129172573, 20.86980557166033, 31.763762002375774, 46.63335753107305, 21.206949619877204, 52.44043212360176, 41.45713725737416, 38.88586835658078]
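The per-channel scores printed above are plain Pearson correlation, MSE, and a per-channel mean between the original and reconstructed traces. A sketch of how plot_recon_metrics presumably computes them (the function name is hypothetical, and treating the reported mean as the mean of the reconstruction is an assumption):

def channel_metrics_sketch(x, x_rec):
    # x, x_rec: (n_windows, seq_len, n_channels)
    xf = x.reshape(-1, x.shape[-1])  # concatenate windows along time
    rf = x_rec.reshape(-1, x_rec.shape[-1])
    corr = np.array([np.corrcoef(xf[:, c], rf[:, c])[0, 1] for c in range(xf.shape[1])])
    err = ((xf - rf) ** 2).mean(axis=0)  # per-channel MSE
    return corr, err, rf.mean(axis=0)    # per-channel mean of the reconstruction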
bhvs = {'crawling': np.array([0]),
        'high picking treats': np.array([1]),
        'low picking treats': np.array([2]),
        'pg': np.array([3]),
        'sitting still': np.array([4]),
        'grooming': np.array([5]),
        'no_behavior': np.array([-1])}
inv_bhvs = {int(v): k for k, v in bhvs.items()}
test_dataset = TensorDataset(torch.from_numpy(X_test))
z_run_test = vrae.transform(test_dataset, save = False)
z_run_all = np.vstack((z_run, z_run_test))
y_all = np.vstack((y_train, y_test))
visualize(z_run = z_run_all, y = y_all, inv_bhvs = inv_bhvs, one_in = 4)
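visualize presumably projects the latent vectors to 2-D (the notebook imports both PCA and TSNE) and colors points by behavior label, keeping one point in four per one_in = 4. A hedged matplotlib sketch of that view, under those assumptions:

from sklearn.manifold import TSNE

z_sub, y_sub = z_run_all[::4], y_all[::4, 0]  # keep one point in four
z_2d = TSNE(n_components=2, random_state=0).fit_transform(z_sub)
for lbl in np.unique(y_sub):
    pts = z_2d[y_sub == lbl]
    plt.scatter(pts[:, 0], pts[:, 1], s=3, label=inv_bhvs[int(lbl)])
plt.legend(markerscale=3)
plt.show()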